# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
import random
import re
import time
import traceback
from collections import defaultdict
from typing import Any, List, Optional, Union

import datasets
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from qwen_vl_utils import process_vision_info
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from x.io import JsonWrap
from x.qwen.data_format import slim_messages

import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask

logger = logging.getLogger(__name__)

import copy


def message_translate(messages, to_format='dashscope'):
    if to_format == 'dashscope':
        return messages
    
    if to_format == 'openai':
        messages = copy.deepcopy(messages)
        for msg in messages:
            if isinstance(msg['content'], str):
                msg['content'] = [msg['content']]
            new_contents = []
            for content in msg['content']:
                if  isinstance(content, str):
                    new_contents.append({"type": "text", 'text': content})
                elif 'text' in content:
                    new_contents.append({"type": "text", 'text': content['text']})
                elif 'image' in content:
                    img_url = content['image']
                    if img_url.startswith('/'):
                        img_url = f"file://{img_url}"
                    new_contents.append({"type": "image_url", "image_url": {"url": img_url}})
                else:
                    raise NotImplementedError
            msg['content'] = new_contents
        return messages
    if to_format == 'qwen':
        messages = copy.deepcopy(messages)
        for msg in messages:
            if isinstance(msg['content'], str):
                msg['content'] = [msg['content']]
            new_contents = []
            for content in msg['content']:
                if  isinstance(content, str):
                    new_contents.append({"type": "text", 'text': content})
                elif 'text' in content:
                    new_contents.append({"type": "text", 'text': content['text']})
                elif 'image' in content:
                    new_contents.append({"type": "image", "image": content['image']})
                else:
                    raise NotImplementedError
            msg['content'] = new_contents
        return messages

def collate_fn(data_list: list[dict]) -> dict:
    """
    Collate a batch of sample dicts into batched tensors and arrays.

    Args:
        data_list: List of dicts mapping feature names to torch.Tensor or other values.

    Returns:
        Dict where tensor entries are stacked into a torch.Tensor of shape
        (batch_size, *dims) and non-tensor entries are converted to
        np.ndarray of dtype object with shape (batch_size,).
    """
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        if isinstance(data, list):
            # dynamic batch size
            pass
        else:
            # original 
            data = [data]
        # will result different batchsize, so we should build micro batch after collate_fn
        for _ in data:
            for key, val in _.items():
                if isinstance(val, torch.Tensor):
                    tensors[key].append(val)
                else:
                    non_tensors[key].append(val)
    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    for key, val in non_tensors.items():
        non_tensors[key] = np.array(val, dtype=object)

    return {**tensors, **non_tensors}


class RLHFDataset(Dataset):
    """
    Load and preprocess RLHF data from Parquet files.

    - Caches files locally.
    - Reads into a HuggingFace Dataset and tokenizes prompts.
    - Optionally handles images/videos via a ProcessorMixin.
    - Filters prompts over a max length.
    - Supports resuming from checkpoints.

    Args:
        data_files (str or list): Path(s) to Parquet file(s).
        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.
        config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.
        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.
    """

    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        if not isinstance(data_files, (List, ListConfig)):
            data_files = [data_files]

        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)  # use for resume
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = config.get("max_prompt_length", 1024)
        self.return_raw_chat = config.get("return_raw_chat", False)
        self.return_full_prompt = config.get("return_full_prompt", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)

        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
        self.num_workers = min(self.num_workers, os.cpu_count())
        self.use_shm = config.get("use_shm", False)
        self.chat_template_func = config.get("chat_template_func", None)
        self.need_tools_kwargs = config.get("need_tools_kwargs", False)
        self.filter_prompts = config.get("filter_prompts", True)
        self.serialize_dataset = False
        self._download()
        self._read_files_and_tokenize()

    def _download(self, use_origin_parquet=False):
        from verl.utils.fs import copy_to_local

        data_files = self.data_files if not use_origin_parquet else self.original_data_files
        for i, parquet_file in enumerate(data_files):
            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)

    def _read_files_and_tokenize(self):
        dataframes = []
        for parquet_file in self.data_files:
            # read parquet files and cache
            dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
            dataframes.append(dataframe)
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

        print(f"dataset len: {len(self.dataframe)}")

        # filter out too long prompts
        if self.filter_overlong_prompts:
            tokenizer = self.tokenizer
            prompt_key = self.prompt_key
            self.dataframe = self.dataframe.filter(
                lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
                num_proc=self.num_workers,
                desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
            )

            print(f"filter dataset len: {len(self.dataframe)}")

    def resume_dataset_state(self):
        self.serialize_dataset = not hasattr(self, "original_data_files")
        # resume dataframe if not it's serialized in data.pt
        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)  # download and resume from original parquet files
            self._read_files_and_tokenize()
        else:
            print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")

    def __len__(self):
        return len(self.dataframe)

    def _build_messages(self, example: dict):
        messages: list = example.pop(self.prompt_key)

        if self.image_key in example or self.video_key in example:
            for message in messages:
                content = message["content"]
                content_list = []
                for segment in re.split("(<image>|<video>)", content):
                    if segment == "<image>":
                        content_list.append({"type": "image"})
                    elif segment == "<video>":
                        content_list.append({"type": "video"})
                    else:
                        content_list.append({"type": "text", "text": segment})

                message["content"] = content_list

        return messages

    def __getitem__(self, item):
        """
        Note that we also return the raw_input_ids so that it can be combined with other chat template
        """
        row_dict: dict = self.dataframe[item]
        messages = self._build_messages(row_dict)
        model_inputs = {}

        if self.processor is not None:
            from verl.utils.dataset.vision_utils import process_image, process_video

            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            multi_modal_data = {}

            images = None
            if self.image_key in row_dict:
                raw_images = [image for image in row_dict.pop(self.image_key)]
                images = [process_image(image) for image in raw_images]
                multi_modal_data["image"] = images

            videos = None
            if self.video_key in row_dict:
                videos = [process_video(video) for video in row_dict.pop(self.video_key)]
                multi_modal_data["video"] = [video.numpy() for video in videos]

            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

            if "second_per_grid_ts" in model_inputs:
                model_inputs.pop("second_per_grid_ts")

            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
            row_dict["multi_modal_data"] = multi_modal_data
            row_dict["multi_modal_inputs"] = dict(model_inputs)

            # second_per_grid_ts isn't used for training, just for mrope
            row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

        else:
            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]  # (1, 3, seq_len)

        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "middle":
                left_half = self.max_prompt_length // 2
                right_half = self.max_prompt_length - left_half
                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids
        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages

        # get prompts with chat template
        if self.return_full_prompt:
            row_dict["full_prompts"] = raw_prompt  # array of strings

        # add index for each prompt
        index = row_dict.get("extra_info", {}).get("index", 0)
        tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
        need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
        if need_tools_kwargs and not tools_kwargs:
            logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
        row_dict["index"] = index
        row_dict["tools_kwargs"] = tools_kwargs

        resized_width, resized_height = images[-1].size
        width, height = raw_images[-1]['width'], raw_images[-1]['height']
        row_dict['extra_info'] = {}
        row_dict['extra_info']['resized_width'] = resized_width
        row_dict['extra_info']['resized_height'] = resized_height
        row_dict['extra_info']['width'] = width
        row_dict['extra_info']['height'] = height
        return row_dict

    def __getstate__(self):
        if not self.serialize_dataset:
            state = self.__dict__.copy()

            if "dataframe" in state:
                del state["dataframe"]
            return state

        return self.__dict__.copy()

class Qwen25VLDataset(RLHFDataset):
    def __init__(self, data_files: str | List[str], tokenizer: PreTrainedTokenizer, config: DictConfig, processor: Any | None = None):
        if '@' in data_files and data_files.split('@')[-1].isdigit():
            data_files, max_dataset_length = data_files.split('@')
            max_dataset_length = int(max_dataset_length)
        else:
            max_dataset_length = None
        if isinstance(data_files, str):
            data_files = [data_files]
        
        for data_file in data_files:
            self.dataframe = JsonWrap(data_files, "auto")
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.index_remap = None
        if max_dataset_length is not None:
            self.max_dataset_length = min(max_dataset_length, len(self.dataframe))
            if self.max_dataset_length < len(self.dataframe):
                self.index_remap = list(range(len(self.dataframe)))  # Map indices up to max length
                random.shuffle(self.index_remap)
                self.index_remap = self.index_remap[:self.max_dataset_length]
        else:
            self.max_dataset_length = None
            
        self.max_pixels = 12800*28*28
        self.min_pixels = 4*28*28
        self.num_image_limit = config.get("num_image_limit", 2)
        # unused maybe
        self.max_prompt_length = config.get("max_prompt_length", 1024)
        self.return_raw_chat = config.get("return_raw_chat", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
        self.need_tools_kwargs = config.get("need_tools_kwargs", False)

       
    def _remap_index(self, index):
        if self.index_remap is not None:
            index = self.index_remap[index]
        return index

    def __len__(self):
        if self.index_remap:
            return len(self.index_remap)
        return len(self.dataframe)

    def _build_each_inputs(self, row_dict):
        messages = row_dict['messages']
        messages = slim_messages(messages, num_image_limit=self.num_image_limit)
        last_image_ele = None
        for msg in messages:
            for content in msg['content']:
                # Very Important
                if 'image' in content:
                    if 'min_pixels' not in content: # TODO fix bug, respect to the resized height
                        content['min_pixels'] = self.min_pixels
                    if 'max_pixels' not in content:
                        content['max_pixels'] = self.max_pixels
                    last_image_ele = content
        assert messages[-1]['role'] == 'user'

        assert self.processor is not None
        from verl.utils.dataset.vision_utils import (process_image,
                                                     process_video)

        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        multi_modal_data = {}
        image_inputs, video_inputs = process_vision_info(messages)
        assert 0 < len(image_inputs)<=self.num_image_limit
        
        width, height = last_image_ele['width'], last_image_ele['height']
        resized_width, resized_height = image_inputs[-1].size
        
        model_inputs = self.processor(text=[raw_prompt], images=image_inputs, videos=video_inputs, return_tensors="pt")
        if image_inputs is not None:
            assert sum(round(_.size[0]*_.size[1]/(28*28)) for _ in image_inputs) == (model_inputs['input_ids'] == 151655).sum()

        multi_modal_data = {
            'image': image_inputs
        }
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

        if "second_per_grid_ts" in model_inputs:
            model_inputs.pop("second_per_grid_ts")

        # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
        row_dict["multi_modal_data"] = multi_modal_data
        row_dict["multi_modal_inputs"] = dict(model_inputs)

        # second_per_grid_ts isn't used for training, just for mrope
        row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)


        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )
        assert self.processor.image_processor.__class__.__name__ != "Qwen2_5VLImageProcessor"
        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]  # (1, 3, seq_len)

        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        row_dict['reward_model'] = {
            "style": "rule",
            "ground_truth": row_dict['check_options']
        }

        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids
        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict["raw_prompt"] = message_translate(messages, to_format="openai")

        # add index for each prompt
        index = row_dict.get("extra_info", {}).get("index", 0)
        tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
        need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
        if need_tools_kwargs and not tools_kwargs:
            logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
        row_dict["index"] = index
        row_dict["tools_kwargs"] = tools_kwargs
        if 'extra_info' not in row_dict:
            row_dict['extra_info'] = {}
        row_dict['extra_info']['resized_width'] = resized_width
        row_dict['extra_info']['resized_height'] = resized_height
        row_dict['extra_info']['width'] = width
        row_dict['extra_info']['height'] = height

        return row_dict

    def _inner_get_item(self, item):
        """
        Note that we also return the raw_input_ids so that it can be combined with other chat template
        """
        line = self.dataframe[self._remap_index(item)]
        return self._build_each_inputs(line)

    def __getitem__(self, item):
        while True:
            try:
                res = self._inner_get_item(item)
                return res
            except Exception as e:
                traceback.print_exc()
                print('Error line ', self.dataframe[self._remap_index(item)])
                time.sleep(1)
                item = random.randint(0, len(self) - 1)
                continue

class Qwen25VLNoRolloutDataset(RLHFDataset):
    def __init__(self, data_files: str | List[str], tokenizer: PreTrainedTokenizer, config: DictConfig, processor: Any | None = None):
        if '@' in data_files and data_files.split('@')[-1].isdigit():
            data_files, max_dataset_length = data_files.split('@')
            max_dataset_length = int(max_dataset_length)
        else:
            max_dataset_length = None

        self.dataframe = JsonWrap(data_files, "auto")
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.trace_sample_k = config.get("trace_sample_k", 4)
        self.index_remap = None

        if max_dataset_length is not None:
            self.max_dataset_length = min(max_dataset_length, len(self.dataframe))
            if self.max_dataset_length < len(self.dataframe):
                self.index_remap = list(range(len(self.dataframe)))  # Map indices up to max length
                random.shuffle(self.index_remap)
                self.index_remap = self.index_remap[:self.max_dataset_length]
        else:
            self.max_dataset_length = None
            
        self.max_pixels = 12800*28*28
        self.min_pixels = 4*28*28


        # unused maybe
        self.max_prompt_length = config.get("max_prompt_length", 1024)
        self.return_raw_chat = config.get("return_raw_chat", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
        self.need_tools_kwargs = config.get("need_tools_kwargs", False)

    def __len__(self, ):
        if self.index_remap:
            return len(self.index_remap)
        return len(self.dataframe)
    def _remap_index(self, index):
        if self.index_remap is not None:
            index = self.index_remap[index]
        return index

    def _build_each_inputs(self, turn, row_dict):
        messages = turn['messages']
        response = turn['response']
        tmp_response_ids = self.tokenizer(response)['input_ids']
        tmp_response_ids = torch.tensor(tmp_response_ids, dtype=torch.long)
        response_ids = torch.ones(self.config.max_response_length, dtype=torch.long)
        assert len(tmp_response_ids) + 1 < self.config.max_response_length # need a eos
        response_ids[:len(tmp_response_ids)] = tmp_response_ids
        row_dict['response'] = turn['response']
        row_dict['response_ids'] = response_ids
        
        # TODO response ids
        row_dict['_advantage'] = turn['_advantage']
        row_dict['_reward'] = turn['_reward']
        
        for msg in messages:
            for content in msg['content']:
                # Very Important
                if 'image' in content:
                    if 'min_pixels' not in content: # TODO fix bug, respect to the resized height
                        content['min_pixels'] = self.min_pixels
                    if 'max_pixels' not in content:
                        content['max_pixels'] = self.max_pixels
        assert messages[-1]['role'] == 'user'

        assert self.processor is not None
        from verl.utils.dataset.vision_utils import (process_image,
                                                     process_video)

        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        multi_modal_data = {}
        image_inputs, video_inputs = process_vision_info(messages)
      

        model_inputs = self.processor(text=raw_prompt, images=image_inputs, videos=video_inputs, return_tensors="pt")
        if image_inputs is not None:
            assert sum(round(_.size[0]*_.size[1]/(28*28)) for _ in image_inputs) == (model_inputs['input_ids'] == 151655).sum()

        row_dict["multi_modal_data"] = {
            'image': image_inputs
        }
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

        if "second_per_grid_ts" in model_inputs:
            model_inputs.pop("second_per_grid_ts")

        # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
        row_dict["multi_modal_data"] = multi_modal_data
        row_dict["multi_modal_inputs"] = dict(model_inputs)

        # second_per_grid_ts isn't used for training, just for mrope
        row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)


        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )
        assert self.processor.image_processor.__class__.__name__ != "Qwen2_5VLImageProcessor"
        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]  # (1, 3, seq_len)

        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids
        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages

        # add index for each prompt
        index = row_dict.get("extra_info", {}).get("index", 0)
        tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
        need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
        if need_tools_kwargs and not tools_kwargs:
            logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
        row_dict["index"] = index
        row_dict["tools_kwargs"] = tools_kwargs
        return row_dict

    def _inner_get_item(self, item):
        """
        Note that we also return the raw_input_ids so that it can be combined with other chat template
        """
        row_dict: dict = self.dataframe[self._remap_index(item)]

        turns = row_dict.pop('turns')
        # turns = random.sample(turns, k=min(len(turns), 4))
        turns = stratified_sample(turns, k=min(len(turns), self.trace_sample_k))
        row_dict_list = []
        for turn in turns:
            tmp_row_dict = self._build_each_inputs(turn, copy.deepcopy(row_dict))
            row_dict_list.append(tmp_row_dict)
        return row_dict_list

    def __getitem__(self, item):
        while True:
            try:
                res = self._inner_get_item(item)
                return res
            except Exception as e:
                traceback.print_exc()
                print('Error line ', self.dataframe[self._remap_index(item)])
                time.sleep(1)
                item = random.randint(0, len(self) - 1)
                continue


class TrajDataset(Qwen25VLNoRolloutDataset):
    

    def _inner_get_item(self, item):
        """
        Note that we also return the raw_input_ids so that it can be combined with other chat template
        """
        row_dict: dict = self.dataframe[item]

        return {"line": row_dict}
        # return row_dict_list

    def __getitem__(self, item):
        line = self.dataframe[self._remap_index(item)]
        def fix_line(line):
            for step in line['steps']:
                check_options = copy.deepcopy(step['action_content'])
                if 'annotation' in step:
                    assert step['annotation'] in ["GOOD","HARMFUL","NEUTRAL"]
                    check_options['annotation'] = step['annotation']
                else:
                    check_options['annotation'] = "GOOD"
                if 'bbox' in step['action_content']: # fix bug
                    check_options['candidate_bbox'] = step['action_content']['bbox'] 
                else:
                    check_options['candidate_bbox'] = []
                step['check_options'] = check_options
                
            return line
        return {"line": np.array(fix_line(line), dtype=object)}
    

def stratified_sample(turns, k):
    if k == len(turns):
        return turns
    # 分层
    positive = [t for t in turns if t['_advantage'] >= 0]
    negative = [t for t in turns if t['_advantage'] < 0]
    
    p = len(positive)
    n_p = len(negative)
    total = p + n_p
    
    if total == 0 or k <= 0:
        return []
    
    # 计算期望样本数
    k_p = round(k * p / total) if p > 0 else 0
    k_p = max(0, min(p, k_p))
    k_n = k - k_p
    
    # 调整样本数，确保不超过各层容量
    if k_n > n_p:
        diff = k_n - n_p
        k_n = n_p
        k_p = max(0, min(p, k - k_n))
    elif k_n < 0:
        k_n = 0
        k_p = min(p, k)

    # 从每层中随机抽样
    sampled = []
    if k_p > 0 and positive:
        sampled.extend(random.sample(positive, k_p))
    if k_n > 0 and negative:
        sampled.extend(random.sample(negative, k_n))

    # 如果仍不足，补充样本（不破坏分层结构）
    while len(sampled) < k:
        remaining = [t for t in turns if t not in sampled]
        if not remaining:
            break
        sampled.append(random.choice(remaining))

    return sampled[:k]